
Conversation

felipemello1
Contributor

Memory freebies


I don't think that loss/reward is a good way to check correctness here, but I compared the functions locally and they produce the same output.


meta-cla bot added the CLA Signed label on Oct 7, 2025
return logprobs

# Convert to fp32 for numerical stability
scaled_logits_fp32 = scaled_logits.float()
Contributor

Noob question: what's the dtype for scaled_logits?

Contributor Author

.float() casts it to torch.float32.
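For reference, a quick illustrative check (the bf16 starting dtype is an assumption for the example, not taken from this PR; .float() always yields torch.float32):

import torch

# Illustrative only: scaled_logits would come from the model in its compute dtype.
scaled_logits = torch.randn(2, 4, dtype=torch.bfloat16)
scaled_logits_fp32 = scaled_logits.float()
print(scaled_logits_fp32.dtype)  # torch.float32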

@felipemello1
Contributor Author

@ebsmothers @Jack-Khuu @joecummings @pbontrager can some of you confirm that I don't need to do the all_gather that was happening in selective_log_softmax? Maybe it's necessary if trainer.parallelism.disable_loss_parallel=False? But this is off for all of our configs.

import torch.nn.functional as F


def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
Contributor

Should we also delete this function?

Contributor Author

It's used in 3 other places, so I'll leave it there for now. We'll probably need some larger refactoring later to clean up / organize the losses.

@casteryh
Contributor

casteryh commented Oct 8, 2025

Just curious: if we torch.compile the textbook implementation, do we still need a manual optimization like selective_log_softmax?
Never mind, just remove selective_log_softmax altogether.
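For context, a minimal sketch of the two variants being discussed. The function names and the second implementation are illustrative assumptions, not the repo's exact code; the point is that the textbook version goes through the full log-softmax, while the manual version only keeps per-token values:

import torch
import torch.nn.functional as F

def textbook_logprobs(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Textbook version: [batch, seq, vocab] log-softmax, then gather the target column.
    # torch.compile may fuse this so the full intermediate need not be written out.
    return F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

def manual_logprobs(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # "Selective" style: gather the target logit and subtract logsumexp,
    # avoiding the full log-softmax allocation in eager mode.
    target_logits = logits.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    return target_logits - torch.logsumexp(logits, dim=-1)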


# compile loss
logger.info("Compiling loss")
self.loss = torch.compile(self.loss)
Member

Is there any circumstance under which this command would fail?

Contributor Author

Can't think of one in our scenario, but if/when it happens, we can fix it.
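For what it's worth, torch.compile itself is lazy, which is part of why a wrap-time failure is unlikely. A rough sketch of where errors would actually surface (the loss function and shapes below are made up for illustration):

import torch
import torch.nn.functional as F

def cross_entropy_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Stand-in loss for illustration only.
    return F.cross_entropy(logits.flatten(0, 1), targets.flatten())

# Wrapping essentially never fails: no compilation happens yet.
compiled_loss = torch.compile(cross_entropy_loss)

logits = torch.randn(2, 8, 32)
targets = torch.randint(0, 32, (2, 8))
loss = compiled_loss(logits, targets)  # compilation happens on this first call;
# unsupported constructs cause graph breaks that fall back to eager, while hard
# backend errors would raise here rather than at wrap time.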

logprobs = selective_log_softmax(scaled_logits, input_ids)
return logprobs

# Cast up to fp32 for numerical stability
Member

Nit: I would change this to something like "ensure logits are in fp32", because they could already be in fp32, in which case there's no "casting up".
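Concretely, the suggestion would make that line read something like this (illustrative wording):

# Ensure logits are in fp32 for numerical stability (no-op if already fp32)
scaled_logits_fp32 = scaled_logits.float()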

felipemello1 merged commit 75815e1 into meta-pytorch:main on Oct 9, 2025
8 checks passed
felipemello1 deleted the compile_loss branch on October 9, 2025, 14:45